{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div align=\"center\">\n",
    "<h2>Basic RAG</h2>\n",
    "</div>\n",
    "\n",
    "<div align=\"center\">\n",
    "    <h3 ><a href=\"https://aiengineering.academy/\" target=\"_blank\">AI Engineering.academy</a></h3>\n",
    "    \n",
    "    \n",
    "</div>\n",
    "\n",
    "<div align=\"center\">\n",
    "<a href=\"https://aiengineering.academy/\" target=\"_blank\">\n",
    "<img src=\"https://raw.githubusercontent.com/adithya-s-k/AI-Engineering.academy/main/assets/banner.png\" alt=\"Ai Engineering. Academy\" width=\"50%\">\n",
    "</a>\n",
    "</div>\n",
    "\n",
    "\n",
    "<div align=\"center\">\n",
    "\n",
    "[![GitHub Stars](https://img.shields.io/github/stars/adithya-s-k/AI-Engineering.academy?style=social)](https://github.com/adithya-s-k/AI-Engineering.academy/stargazers)\n",
    "[![GitHub Forks](https://img.shields.io/github/forks/adithya-s-k/AI-Engineering.academy?style=social)](https://github.com/adithya-s-k/AI-Engineering.academy/network/members)\n",
    "[![GitHub Issues](https://img.shields.io/github/issues/adithya-s-k/AI-Engineering.academy)](https://github.com/adithya-s-k/AI-Engineering.academy/issues)\n",
    "[![GitHub Pull Requests](https://img.shields.io/github/issues-pr/adithya-s-k/AI-Engineering.academy)](https://github.com/adithya-s-k/AI-Engineering.academy/pulls)\n",
    "[![License](https://img.shields.io/github/license/adithya-s-k/AI-Engineering.academy)](https://github.com/adithya-s-k/AI-Engineering.academy/blob/main/LICENSE)\n",
    "\n",
    "</div>\n",
    "\n",
    "## Introduction\n",
    "Retrieval-Augmented Generation (RAG) is a powerful technique that combines the strengths of large language models with the ability to retrieve relevant information from a knowledge base. This approach enhances the quality and accuracy of generated responses by grounding them in specific, retrieved information.\n",
    "\n",
    "This notebook aims to provide a clear and concise introduction to RAG, suitable for beginners who want to understand and implement this technology.\n",
    "\n",
    "### Motivation\n",
    "\n",
    "Traditional language models generate text based on learned patterns from training data. However, when they are presented with queries that require specific, updated, or niche information, they may struggle to provide accurate responses. RAG addresses this limitation by incorporating a retrieval step that provides the language model with relevant context to generate more informed answers.\n",
    "\n",
    "### Method Details\n",
    "\n",
    "#### Document Preprocessing and Vector Store Creation\n",
    "\n",
    "1. **Document Chunking**: The knowledge base documents (e.g., PDFs, articles) are preprocessed and split into manageable chunks. This is done to create a searchable corpus that can be efficiently used in the retrieval process.\n",
    "   \n",
    "2. **Embedding Generation**: Each chunk is converted into a vector representation using pre-trained embeddings (e.g., OpenAI's embeddings). This allows the documents to be stored in a vector database, such as Qdrant, enabling efficient similarity searches.\n",
    "\n",
    "#### Retrieval-Augmented Generation Workflow\n",
    "\n",
    "1. **Query Input**: A user provides a query that needs to be answered.\n",
    "   \n",
    "2. **Retrieval Step**: The query is embedded into a vector using the same embedding model that was used for the documents. A similarity search is then performed in the vector database to find the most relevant document chunks.\n",
    "\n",
    "3. **Generation Step**: The retrieved document chunks are passed to a large language model (e.g., GPT-4) as additional context. The model uses this context to generate a more accurate and relevant response.\n",
    "\n",
    "### Key Features of RAG\n",
    "\n",
    "1. **Contextual Relevance**: By grounding responses in actual retrieved information, RAG models can produce more contextually relevant and accurate answers.\n",
    "   \n",
    "2. **Scalability**: The retrieval step can scale to handle large knowledge bases, allowing the model to draw from vast amounts of information.\n",
    "\n",
    "3. **Flexibility in Use Cases**: RAG can be adapted for a variety of applications, including question answering, summarization, recommendation systems, and more.\n",
    "\n",
    "4. **Improved Accuracy**: Combining generation with retrieval often yields more precise results, especially for queries requiring specific or lesser-known information.\n",
    "\n",
    "### Benefits of this Approach\n",
    "\n",
    "1. **Combines Strengths of Both Retrieval and Generation**: RAG effectively merges retrieval-based methods with generative models, allowing for both precise fact-finding and natural language generation.\n",
    "\n",
    "2. **Enhanced Handling of Long-Tail Queries**: It is particularly effective for queries where specific and less frequently occurring information is needed.\n",
    "\n",
    "3. **Domain Adaptability**: The retrieval mechanism can be tuned to specific domains, ensuring that the generated responses are grounded in the most relevant and accurate domain-specific information.\n",
    "\n",
    "### Conclusion\n",
    "\n",
    "Retrieval-Augmented Generation (RAG) represents an innovative fusion of retrieval and generation techniques, significantly enhancing the capabilities of language models by grounding their outputs in relevant external information. This approach can be particularly valuable in scenarios requiring precise, context-aware responses, such as customer support, academic research, and more. As AI continues to evolve, RAG stands out as a powerful method for building more reliable and context-sensitive AI systems.\n",
    "\n",
    "### Prerequisites\n",
    "- Preferably Python 3.11\n",
    "- Jupyter Notebook or JupyterLab\n",
    "- LLM API Key\n",
    "    - You can use any llm of your choice in this notebook we have use OpenAI and Gpt-4o-mini\n",
    "\n",
    "With these steps, you can implement a basic RAG system to enhance the capabilities of language models by incorporating real-world, up-to-date information, improving their effectiveness in various applications."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up the Environment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install llama-index\n",
    "!pip install llama-index-vector-stores-qdrant \n",
    "!pip install llama-index-readers-file \n",
    "!pip install llama-index-embeddings-fastembed \n",
    "!pip install llama-index-llms-openai\n",
    "!pip install llama-index-llms-groq\n",
    "!pip install -U qdrant_client fastembed\n",
    "!pip install python-dotenv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/adithya/miniconda3/envs/01_Basic_RAG/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.\n"
     ]
    }
   ],
   "source": [
    "# Standard library imports\n",
    "import logging\n",
    "import sys\n",
    "import os\n",
    "\n",
    "# Third-party imports\n",
    "from dotenv import load_dotenv\n",
    "from IPython.display import Markdown, display\n",
    "\n",
    "# Qdrant client import\n",
    "import qdrant_client\n",
    "\n",
    "# LlamaIndex core imports\n",
    "from llama_index.core import VectorStoreIndex, SimpleDirectoryReader\n",
    "from llama_index.core import Settings\n",
    "\n",
    "# LlamaIndex vector store import\n",
    "from llama_index.vector_stores.qdrant import QdrantVectorStore\n",
    "\n",
    "# Embedding model imports\n",
    "from llama_index.embeddings.fastembed import FastEmbedEmbedding\n",
    "from llama_index.embeddings.openai import OpenAIEmbedding\n",
    "\n",
    "# LLM import\n",
    "from llama_index.llms.openai import OpenAI\n",
    "from llama_index.llms.groq import Groq\n",
    "# Load environment variables\n",
    "load_dotenv()\n",
    "\n",
    "# Get OpenAI API key from environment variables\n",
    "OPENAI_API_KEY = os.getenv(\"OPENAI_API_KEY\")\n",
    "GROK_API_KEY = os.getenv(\"GROQ_API_KEY\")\n",
    "\n",
    "# Setting up Base LLM\n",
    "Settings.llm = OpenAI(\n",
    "    model=\"gpt-4o-mini\", temperature=0.1, max_tokens=1024, streaming=True\n",
    ")\n",
    "\n",
    "# Settings.llm = Groq(model=\"llama3-70b-8192\" , api_key=GROK_API_KEY)\n",
    "\n",
    "# Set the embedding model\n",
    "# Option 1: Use FastEmbed with BAAI/bge-base-en-v1.5 model (default)\n",
    "# Settings.embed_model = FastEmbedEmbedding(model_name=\"BAAI/bge-base-en-v1.5\")\n",
    "\n",
    "# Option 2: Use OpenAI's embedding model (commented out)\n",
    "# If you want to use OpenAI's embedding model, uncomment the following line:\n",
    "Settings.embed_model = OpenAIEmbedding(embed_batch_size=10, api_key=OPENAI_API_KEY)\n",
    "\n",
    "# Qdrant configuration (commented out)\n",
    "# If you're using Qdrant, uncomment and set these variables:\n",
    "# QDRANT_CLOUD_ENDPOINT = os.getenv(\"QDRANT_CLOUD_ENDPOINT\")\n",
    "# QDRANT_API_KEY = os.getenv(\"QDRANT_API_KEY\")\n",
    "\n",
    "# Note: Remember to add QDRANT_CLOUD_ENDPOINT and QDRANT_API_KEY to your .env file if using Qdrant Hosted version"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Setting Up Observability"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !pip install arize-phoenix\n",
    "# !pip install openinference-instrumentation-llama-index\n",
    "# !pip install -U llama-index-callbacks-arize-phoenix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import phoenix as px\n",
    "\n",
    "# (session := px.launch_app()).view()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from openinference.instrumentation.langchain import LangChainInstrumentor\n",
    "# from openinference.instrumentation.llama_index import LlamaIndexInstrumentor\n",
    "# from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter\n",
    "# from opentelemetry.sdk.trace import SpanLimits, TracerProvider\n",
    "# from opentelemetry.sdk.trace.export import SimpleSpanProcessor\n",
    "\n",
    "# endpoint = \"http://127.0.0.1:6006/v1/traces\"\n",
    "# tracer_provider = TracerProvider(span_limits=SpanLimits(max_attributes=100_000))\n",
    "# tracer_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter(endpoint)))\n",
    "\n",
    "# LlamaIndexInstrumentor().instrument(skip_dep_check=True, tracer_provider=tracer_provider)\n",
    "# # LangChainInstrumentor().instrument(skip_dep_check=True, tracer_provider=tracer_provider)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "🔃 Loading Data\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading files: 100%|██████████| 1/1 [00:00<00:00,  4.27file/s]\n"
     ]
    }
   ],
   "source": [
    "# lets loading the documents using SimpleDirectoryReader\n",
    "\n",
    "print(\"🔃 Loading Data\")\n",
    "\n",
    "from llama_index.core import Document\n",
    "reader = SimpleDirectoryReader(\"../data/\" , recursive=True)\n",
    "documents = reader.load_data(show_progress=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up Vector Database\n",
    "\n",
    "We will be using qDrant as the Vector database\n",
    "There are 4 ways to initialize qdrant \n",
    "\n",
    "1. Inmemory\n",
    "```python\n",
    "client = qdrant_client.QdrantClient(location=\":memory:\")\n",
    "```\n",
    "2. Disk\n",
    "```python\n",
    "client = qdrant_client.QdrantClient(path=\"./data\")\n",
    "```\n",
    "3. Self hosted or Docker\n",
    "```python\n",
    "\n",
    "client = qdrant_client.QdrantClient(\n",
    "    # url=\"http://<host>:<port>\"\n",
    "    host=\"localhost\",port=6333\n",
    ")\n",
    "```\n",
    "\n",
    "4. Qdrant cloud\n",
    "```python\n",
    "client = qdrant_client.QdrantClient(\n",
    "    url=QDRANT_CLOUD_ENDPOINT,\n",
    "    api_key=QDRANT_API_KEY,\n",
    ")\n",
    "```\n",
    "\n",
    "for this notebook we will be using qdrant cloud"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# creating a qdrant client instance\n",
    "\n",
    "client = qdrant_client.QdrantClient(\n",
    "    # you can use :memory: mode for fast and light-weight experiments,\n",
    "    # it does not require to have Qdrant deployed anywhere\n",
    "    # but requires qdrant-client >= 1.1.1\n",
    "    # location=\":memory:\"\n",
    "    # otherwise set Qdrant instance address with:\n",
    "    # url=QDRANT_CLOUD_ENDPOINT,\n",
    "    # otherwise set Qdrant instance with host and port:\n",
    "    host=\"localhost\",\n",
    "    port=6333\n",
    "    # set API KEY for Qdrant Cloud\n",
    "    # api_key=QDRANT_API_KEY,\n",
    "    # path=\"./db/\"\n",
    ")\n",
    "\n",
    "vector_store = QdrantVectorStore(client=client, collection_name=\"01_Basic_RAG\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Ingest Data into vector DB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Parsing nodes: 100%|██████████| 58/58 [00:00<00:00, 555.31it/s]\n",
      "Generating embeddings: 100%|██████████| 58/58 [00:08<00:00,  7.04it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of chunks added to vector DB : 58\n"
     ]
    }
   ],
   "source": [
    "## ingesting data into vector database\n",
    "\n",
    "## lets set up an ingestion pipeline\n",
    "\n",
    "from llama_index.core.node_parser import TokenTextSplitter\n",
    "from llama_index.core.node_parser import SentenceSplitter\n",
    "from llama_index.core.node_parser import MarkdownNodeParser\n",
    "from llama_index.core.node_parser import SemanticSplitterNodeParser\n",
    "from llama_index.core.ingestion import IngestionPipeline\n",
    "\n",
    "pipeline = IngestionPipeline(\n",
    "    transformations=[\n",
    "        # MarkdownNodeParser(include_metadata=True),\n",
    "        # TokenTextSplitter(chunk_size=500, chunk_overlap=20),\n",
    "        SentenceSplitter(chunk_size=1024, chunk_overlap=20),\n",
    "        # SemanticSplitterNodeParser(buffer_size=1, breakpoint_percentile_threshold=95 , embed_model=Settings.embed_model),\n",
    "        Settings.embed_model,\n",
    "    ],\n",
    "    vector_store=vector_store,\n",
    ")\n",
    "\n",
    "# Ingest directly into a vector db\n",
    "nodes = pipeline.run(documents=documents , show_progress=True)\n",
    "print(\"Number of chunks added to vector DB :\",len(nodes))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting Up Index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "index = VectorStoreIndex.from_vector_store(vector_store=vector_store)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Modifying Prompts and Prompt Tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core import ChatPromptTemplate\n",
    "\n",
    "qa_prompt_str = (\n",
    "    \"Context information is below.\\n\"\n",
    "    \"---------------------\\n\"\n",
    "    \"{context_str}\\n\"\n",
    "    \"---------------------\\n\"\n",
    "    \"Given the context information and not prior knowledge, \"\n",
    "    \"answer the question: {query_str}\\n\"\n",
    ")\n",
    "\n",
    "refine_prompt_str = (\n",
    "    \"We have the opportunity to refine the original answer \"\n",
    "    \"(only if needed) with some more context below.\\n\"\n",
    "    \"------------\\n\"\n",
    "    \"{context_msg}\\n\"\n",
    "    \"------------\\n\"\n",
    "    \"Given the new context, refine the original answer to better \"\n",
    "    \"answer the question: {query_str}. \"\n",
    "    \"If the context isn't useful, output the original answer again.\\n\"\n",
    "    \"Original Answer: {existing_answer}\"\n",
    ")\n",
    "\n",
    "# Text QA Prompt\n",
    "chat_text_qa_msgs = [\n",
    "    (\"system\",\"You are a AI assistant who is well versed with answering questions from the provided context\"),\n",
    "    (\"user\", qa_prompt_str),\n",
    "]\n",
    "text_qa_template = ChatPromptTemplate.from_messages(chat_text_qa_msgs)\n",
    "\n",
    "# Refine Prompt\n",
    "chat_refine_msgs = [\n",
    "    (\"system\",\"Always answer the question, even if the context isn't helpful.\",),\n",
    "    (\"user\", refine_prompt_str),\n",
    "]\n",
    "refine_template = ChatPromptTemplate.from_messages(chat_refine_msgs)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Example of Retrivers \n",
    "\n",
    "- Query Engine\n",
    "- Chat Engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "According to the context information, the encoder is composed of a stack of N = 6 identical layers."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Setting up Query Engine\n",
    "BASE_RAG_QUERY_ENGINE = index.as_query_engine(\n",
    "        similarity_top_k=5,\n",
    "        text_qa_template=text_qa_template,\n",
    "        refine_template=refine_template,)\n",
    "\n",
    "\n",
    "response = BASE_RAG_QUERY_ENGINE.query(\"How many encoders are stacked in the encoder?\")\n",
    "display(Markdown(str(response)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/markdown": [
       "The number of encoders stacked in the encoder is 6."
      ],
      "text/plain": [
       "<IPython.core.display.Markdown object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Setting up Chat Engine\n",
    "BASE_RAG_CHAT_ENGINE = index.as_chat_engine()\n",
    "\n",
    "response = BASE_RAG_CHAT_ENGINE.chat(\"How many encoders are stacked in the encoder?\")\n",
    "display(Markdown(str(response)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Simple Chat Application with RAG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List\n",
    "from llama_index.core.base.llms.types import ChatMessage, MessageRole\n",
    "\n",
    "class ChatEngineInterface:\n",
    "    def __init__(self, index):\n",
    "        self.chat_engine = index.as_chat_engine()\n",
    "        self.chat_history: List[ChatMessage] = []\n",
    "\n",
    "    def display_message(self, role: str, content: str):\n",
    "        if role == \"USER\":\n",
    "            display(Markdown(f\"**Human:** {content}\"))\n",
    "        else:\n",
    "            display(Markdown(f\"**AI:** {content}\"))\n",
    "\n",
    "    def chat(self, message: str) -> str:\n",
    "        # Create a ChatMessage for the user input\n",
    "        user_message = ChatMessage(role=MessageRole.USER, content=message)\n",
    "        self.chat_history.append(user_message)\n",
    "        \n",
    "        # Get response from the chat engine\n",
    "        response = self.chat_engine.chat(message, chat_history=self.chat_history)\n",
    "        \n",
    "        # Create a ChatMessage for the AI response\n",
    "        ai_message = ChatMessage(role=MessageRole.ASSISTANT, content=str(response))\n",
    "        self.chat_history.append(ai_message)\n",
    "        \n",
    "        # Display the conversation\n",
    "        self.display_message(\"USER\", message)\n",
    "        self.display_message(\"ASSISTANT\", str(response))\n",
    "        \n",
    "        print(\"\\n\" + \"-\"*50 + \"\\n\")  # Separator for readability\n",
    "\n",
    "        return str(response)\n",
    "\n",
    "    def get_chat_history(self) -> List[ChatMessage]:\n",
    "        return self.chat_history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "chat_interface = ChatEngineInterface(index)\n",
    "while True:\n",
    "    user_input = input(\"You: \").strip()\n",
    "    if user_input.lower() == 'exit':\n",
    "        print(\"Thank you for chatting! Goodbye.\")\n",
    "        break\n",
    "    chat_interface.chat(user_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# To view chat history:\n",
    "history = chat_interface.get_chat_history()\n",
    "for message in history:\n",
    "    print(f\"{message.role}: {message.content}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm-venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
